#!/usr/bin/env python3
from typing import Callable
import numpy as np
import torch

import rpi
from rpi import logger
from rpi.agents.base import Agent


@torch.no_grad()
def eval_fn(make_env: Callable, agent: Agent, max_episode_len, num_episodes: int = 20, save_video_num_ep: int = 0, verbose: bool = False):
    """Roll out ``agent`` for ``num_episodes`` episodes and summarize the results.

    An environment is obtained from ``make_env(test=True)`` and each episode is
    collected with ``rollout_single_ep`` while ``agent.pi`` is wrapped in the
    ``rpi.helpers.evaluating`` context manager. The whole function runs under
    ``torch.no_grad()``.

    Args:
        - make_env: factory, called as ``make_env(test=True)``
        - agent: agent whose ``act(obs, mode=True)`` is used as the policy
        - max_episode_len: forwarded unchanged to ``rollout_single_ep``
        - num_episodes: number of episodes to roll out
        - save_video_num_ep: number of episodes to save the frames (the first
          ``save_video_num_ep`` episodes are rolled out with ``save_video=True``)
        - verbose: if True, log a start message and the return of every episode

    Returns:
        dict with keys ``returns_mean``, ``returns_std``, ``ep_lens_mean``
        (numpy scalars), ``returns`` and ``ep_lens`` (``wandb.Histogram``),
        ``_returns`` (the raw list of per-episode returns) and, only when at
        least one frame was collected, ``frames`` (flat list of frames across
        all recorded episodes).
    """
    import wandb
    from rpi.helpers.env import rollout_single_ep

    env = make_env(test=True)
    returns = []
    ep_lens = []
    frames = []

    def policy(obs: np.ndarray):
        # NOTE(review): `mode=True` presumably selects deterministic/greedy
        # actions -- confirm against Agent.act.
        return agent.act(obs, mode=True)

    if verbose:
        logger.info('Running evaluation...')

    with rpi.helpers.evaluating(agent.pi):
        for ep_idx in range(num_episodes):
            record_video = ep_idx < save_video_num_ep
            ep = rollout_single_ep(env, policy, max_episode_len, save_video=record_video)
            returns.append(sum(transition['reward'] for transition in ep))
            ep_lens.append(len(ep))

            if verbose:
                logger.info(f'eval {ep_idx + 1} / {num_episodes} -- return: {returns[-1]}')

            # Guard against an empty episode: indexing ep[0] unconditionally
            # would raise IndexError and discard the whole evaluation.
            if len(ep) > 0 and 'frame' in ep[0]:
                frames.extend(transition['frame'] for transition in ep)

    # Convert once instead of rebuilding the array for every statistic.
    returns_arr = np.array(returns)
    out = {'returns_mean': returns_arr.mean(),
           'returns_std': returns_arr.std(),
           'returns': wandb.Histogram(returns),
           'ep_lens_mean': np.array(ep_lens).mean(),
           'ep_lens': wandb.Histogram(ep_lens),
           '_returns': returns}
    if frames:
        out['frames'] = frames
    return out


def evaluate(env, agent, num_episodes=30, save_video_hook=None, render_frames=True):
    """Run ``agent`` in ``env`` for ``num_episodes`` episodes and return the returns.

    Each episode starts with ``env.reset()`` and steps until the environment
    reports ``done``. Observations are converted with ``to_torch`` and given a
    leading dimension via ``unsqueeze(0)`` before ``agent.act``; that dimension
    is squeezed from the action before it is passed to ``env.step``. The
    rollout runs under ``torch.no_grad()``.

    Args:
        env: environment exposing ``reset()``, ``step(action)`` returning a
            4-tuple ``(obs, reward, done, info)``, and
            ``render(mode='rgb_array')``.
        agent: object whose ``act`` maps the converted observation to an action.
        num_episodes: number of episodes to run.
        save_video_hook: optional callable; receives the frames of the first
            episode as a ``uint8`` array when any were captured.
        render_frames: if True, render a frame after every step of the first
            episode.

    Returns:
        List with the summed reward of each episode, in order.
    """
    from rpi import logger
    from rpi.helpers.data import to_torch

    logger.info('Running evaluation...')
    episode_returns = []
    frames = []
    with torch.no_grad():
        for ep_idx in range(num_episodes):
            obs = env.reset()
            done = False
            episode_return = 0
            # Only the first episode is recorded, to keep the video short.
            record = render_frames and ep_idx == 0
            while not done:
                action = agent.act(to_torch(obs).unsqueeze(0)).squeeze(0)
                obs, reward, done, _ = env.step(action)
                episode_return += reward
                if record:
                    frames.append(env.render(mode='rgb_array'))

            episode_returns.append(episode_return)

    # Frames exist only when rendering was requested and at least one episode
    # ran; skip the hook instead of handing it an empty array.
    if save_video_hook and frames:
        save_video_hook(np.asarray(frames, dtype=np.uint8))
    return episode_returns
